import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GlobalAttention, global_mean_pool
from torch_geometric.nn.conv.gatv2_conv import GATv2Conv
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import from_networkx
from torch_geometric.utils import degree
import networkx as nx
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset, Batch
from tqdm import tqdm
import os
from modifydata import MyGraphDataset,FilteredGraphDataset
from torch_geometric.utils import subgraph
import random
from torch.cuda.amp import autocast, GradScaler

from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.nn import GlobalAttention
from torch_geometric.datasets import TUDataset
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KDTree
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import traceback

import torch
from torch_geometric.data import Data

class SplitDataset:
    """Index-aligned view over three datasets that yields triplets.

    Item ``i`` is the tuple ``(dataset_ori[i], dataset_pos[i], dataset_neg[i])``.

    Attributes:
        dataset_ori, dataset_pos, dataset_neg: the wrapped datasets.
        num_features: ``dataset_ori[0].x.shape[1]``, i.e. the second dimension
            of the first origin item's ``x``.

    Raises:
        ValueError: if the three datasets differ in length.
        AttributeError: if ``dataset_ori`` is empty or its first item carries
            no usable ``x`` from which ``num_features`` can be inferred.
    """

    def __init__(self, dataset_ori, dataset_pos, dataset_neg):
        # Explicit raise instead of `assert`: asserts disappear under `python -O`.
        if not (len(dataset_ori) == len(dataset_pos) == len(dataset_neg)):
            raise ValueError("All datasets must have the same length")

        self.dataset_ori = dataset_ori
        self.dataset_pos = dataset_pos
        self.dataset_neg = dataset_neg

        # Infer the feature width from the first origin item. `getattr` with a
        # None default also covers objects that expose `x` but hold None in it,
        # which `hasattr` alone would let through.
        first_item = self.dataset_ori[0] if len(self.dataset_ori) > 0 else None
        features = getattr(first_item, 'x', None)
        if features is None:
            raise AttributeError("No valid data with 'x' attribute found to infer num_features.")
        self.num_features = features.shape[1]

    def __len__(self):
        """Return the number of triplets (the common length of the three datasets)."""
        return len(self.dataset_ori)

    def __getitem__(self, idx):
        """Return the ``(origin, positive, negative)`` items stored at ``idx``."""
        return self.dataset_ori[idx], self.dataset_pos[idx], self.dataset_neg[idx]



# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
root = '/data/liuruiheng/neural-graph/GNN'

# Load the graphs from the project-local dataset class.
# (An earlier variant loaded MyGraphDataset(root=root) here instead.)
dataset = FilteredGraphDataset(max_edges=500000, root=root)

# Standardise the dataset-level feature matrix.
# NOTE(review): the scaler is fitted before the train/val/test split below, so
# its statistics include validation and test graphs.
# NOTE(review): fit_transform returns a NumPy array; whether assigning it to
# `dataset.x` changes the features of graphs obtained by indexing depends on
# FilteredGraphDataset -- confirm against that class.
scaler = StandardScaler()
dataset.x = scaler.fit_transform(dataset.x)

# Every third graph starting at offset 0 / 1 / 2 goes to the origin / positive /
# negative dataset, so index i of the three slices forms one triplet.
dataset_ori, dataset_pos, dataset_neg = dataset[::3], dataset[1::3], dataset[2::3]
dataset = SplitDataset(dataset_ori, dataset_pos, dataset_neg)

# 60% / 10% split for train / validation; the test set takes the remainder so
# the three sizes always sum to the number of triplets.
num_graphs = len(dataset)
train_size, val_size = int(0.6 * num_graphs), int(0.1 * num_graphs)
test_size = num_graphs - train_size - val_size

# NOTE(review): no generator/seed is passed, so the split differs between runs.
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
)

print(len(train_dataset))

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

print(len(train_loader))


class GINEmbedding(torch.nn.Module):
    """GIN encoder that maps every graph of a batch to one embedding vector.

    Pipeline: a linear input projection to ``hidden_dim`` with ReLU, then
    ``num_layers`` blocks of ``GINConv -> ReLU -> BatchNorm1d``, then a
    per-graph mean readout and a final linear map to ``embedding_dim``.

    Args:
        num_features: width of the raw node feature vectors.
        embedding_dim: width of the returned graph embeddings.
        hidden_dim: width used by the projection and every GIN layer.
        num_layers: number of GINConv blocks.
    """

    def __init__(self, num_features, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # Input projection so the GIN layers can all work at `hidden_dim`
        # regardless of the raw feature width.
        self.input_proj = torch.nn.Linear(num_features, hidden_dim)

        for _ in range(num_layers):
            mlp = torch.nn.Sequential(
                torch.nn.Linear(hidden_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, hidden_dim),
            )
            self.convs.append(GINConv(mlp, train_eps=True))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

        # Mean pooling is the readout in use; a GlobalAttention readout was
        # tried here previously and is currently disabled.
        self.embedding = torch.nn.Linear(hidden_dim, embedding_dim)

    def forward(self, x, edge_index, batch):
        """Embed the graphs described by ``x`` / ``edge_index`` / ``batch``.

        Args:
            x: node feature matrix fed to ``input_proj``.
            edge_index: edge list passed to every ``GINConv``.
            batch: node-to-graph assignment passed to ``global_mean_pool``.

        Returns:
            Tensor with one row per graph and ``embedding_dim`` columns. The
            batch dimension is always kept, even for a single graph, so that
            reductions over ``dim=1`` work for every batch size.
        """
        # Always true for instances built by __init__ above; the guard is kept
        # from the original implementation.
        if hasattr(self, 'input_proj'):
            x = F.relu(self.input_proj(x))

        for conv, bn in zip(self.convs, self.batch_norms):
            x = bn(F.relu(conv(x, edge_index)))

        x = global_mean_pool(x, batch)
        # No `.squeeze()` here: it would collapse a batch of one graph (or an
        # embedding_dim of 1) to a 1-D tensor and break the
        # [batch_size, embedding_dim] contract.
        return self.embedding(x)

def getModel():
    """Build a freshly initialised GINEmbedding and move it to ``device``.

    The input width is taken from the module-level ``dataset``; the remaining
    hyper-parameters are fixed (128-d embedding, 64 hidden units, 3 layers).
    """
    return GINEmbedding(
        num_features=dataset.num_features,
        embedding_dim=128,
        hidden_dim=64,
        num_layers=3,
    ).to(device)

# Build the embedding model to train. getModel() constructs the same
# architecture with the same hyper-parameters and already places it on `device`,
# so no further device handling is needed here.
embedding_model = getModel()

# Optional warm start from pretrained weights (currently disabled):
# pretrained_dict = torch.load('gin_model.pth')
# model_dict = embedding_model.state_dict()
#
# # Keep only the entries whose name and shape match the current model.
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.shape == model_dict[k].shape}
# model_dict.update(pretrained_dict)
# embedding_model.load_state_dict(model_dict, strict=False)

# Optimise only the parameters that require gradients.
optimizer = Adam(
    (param for param in embedding_model.parameters() if param.requires_grad),
    lr=0.01,
)

# Halve the learning rate after 10 epochs without improvement of the monitored
# value. The scheduler only acts when its step() method is called.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    verbose=True,
)

def contrastive_loss(H_o, H_p, H_n):
    """Triplet-style loss on origin / positive / negative embeddings.

    For every sample the mean squared difference between the origin and the
    positive embedding, and between the origin and the negative embedding, is
    squashed with a sigmoid. The loss is ``sigmoid(d_pos) - sigmoid(d_neg)``
    averaged over the batch, so it decreases when positives move closer to
    the origin and negatives move further away. Because both distances are
    non-negative, each sigmoid lies in [0.5, 1) and the loss in (-0.5, 0.5).

    Args:
        H_o: origin embeddings, ``[batch_size, embedding_dim]``.
        H_p: positive embeddings, same shape as ``H_o``.
        H_n: negative embeddings, same shape as ``H_o``.
            A single 1-D embedding per argument is also accepted and treated
            as a batch of one.

    Returns:
        Scalar tensor holding the batch-averaged loss.
    """
    # Promote a lone 1-D embedding to [1, embedding_dim] so the per-sample
    # reduction over dim=1 below is valid for every batch size.
    H_o, H_p, H_n = (torch.atleast_2d(h) for h in (H_o, H_p, H_n))

    # Per-sample mean squared difference to the positive / negative embedding.
    dist_pos = torch.mean((H_o - H_p) ** 2, dim=1)
    dist_neg = torch.mean((H_o - H_n) ** 2, dim=1)

    # torch.sigmoid is the library form of 1 / (1 + exp(-x)).
    O_s = torch.sigmoid(dist_pos)
    O_d = torch.sigmoid(dist_neg)

    return (O_s - O_d).mean()

def train(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is an ``(origin, positive, negative)`` triplet of graph batches.
    All three are embedded with ``model``, scored with ``contrastive_loss``
    and used for one optimizer step.

    Args:
        model: embedding network called as ``model(x, edge_index, batch)``.
        train_loader: iterable of triplets; must also support ``len()``.
        optimizer: optimizer that updates ``model``.
        device: device every graph batch is moved to before the forward pass.

    Returns:
        The mean of the per-batch losses (also printed).
    """
    model.train()
    running_loss = 0.0

    for origin, positive, negative in train_loader:
        # Move the whole triplet to the target device first.
        graphs = [part.to(device) for part in (origin, positive, negative)]

        # Embed origin, positive and negative (in that order) with shared weights.
        H_o, H_p, H_n = [model(g.x, g.edge_index, g.batch) for g in graphs]

        loss = contrastive_loss(H_o, H_p, H_n)

        # Standard update: clear stale gradients, backpropagate, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)
    print(f'Training Loss: {avg_loss:.4f}')
    return avg_loss

def evaluate(model, val_loader, device=None):
    """Compute the mean contrastive loss of ``model`` over ``val_loader``.

    The model is put in eval mode and no gradients are tracked. Each batch is
    an ``(origin, positive, negative)`` triplet of graph batches.

    Args:
        model: embedding network called as ``model(x, edge_index, batch)``.
        val_loader: iterable of triplets.
        device: device the graph batches are moved to. When omitted, the
            device of the model's first parameter is used (CPU if the model
            has no parameters).

    Returns:
        The mean of the per-batch losses as a float, with its sign preserved
        (lower is better, the same quantity ``train`` reports). ``nan`` is
        returned when ``val_loader`` yields no batches.
    """
    if device is None:
        # Follow the model instead of relying on a module-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for origin, positive, negative in val_loader:
            origin = origin.to(device)
            positive = positive.to(device)
            negative = negative.to(device)

            H_o = model(origin.x, origin.edge_index, origin.batch)
            H_p = model(positive.x, positive.edge_index, positive.batch)
            H_n = model(negative.x, negative.edge_index, negative.batch)

            total_loss += contrastive_loss(H_o, H_p, H_n).item()
            num_batches += 1

    # An empty split (e.g. a 10% share of a tiny dataset) has no defined mean;
    # report nan rather than failing with ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float('nan')

    # The sign is deliberately NOT flipped: the loss is negative when the model
    # separates positives from negatives, and taking its magnitude would make
    # better models look worse.
    print(f'Validation Loss: {avg_loss:.4f}')
    return avg_loss



# Training schedule: every epoch runs one training pass followed by one
# validation pass; both functions print their own loss.
# NOTE(review): `scheduler` is created earlier in this file but never stepped,
# so the learning rate stays at its initial value for the whole run.
num_epochs = 96
epoch_progress = tqdm(range(num_epochs), desc="Training Progress", unit="epoch")
for epoch in epoch_progress:
    train(embedding_model, train_loader, optimizer, device)
    evaluate(embedding_model, val_loader)

# Final check on the held-out test split, then persist the trained weights.
test_loss = evaluate(embedding_model, test_loader)
print(f'Test Loss: {test_loss:.4f}')
torch.save(embedding_model.state_dict(), './gin_finetune_model_new.pth')